package cn.edu.gznu.wecampus.core;

import java.io.Serializable;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import javax.persistence.Column;
import javax.persistence.JoinColumn;

import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.After;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

/**
 * Aspect that keeps bookkeeping fields of {@link JpaEntity} instances up to date around
 * {@code IRepository} save operations.
 *
 * <p>Before {@code save}, {@code saveAndFlush} and {@code saveAll} it stamps the update
 * time/user, and additionally the create time/user when {@link JpaEntity#isNew()} is true.
 * After {@code save}/{@code saveAndFlush} it copies the id of every {@link JoinColumn}
 * association into the {@link Column} field mapped to the same column name.
 *
 * <p>As a {@code @Component} this is (by Spring's default scope) a singleton shared by all
 * calling threads, so the reflection cache is concurrent and cached {@link Field}s are made
 * accessible once and never toggled back.
 */
@Aspect
@Component
public class IRepositoryAspect {

	/** Per entity class: the (association field, FK column field) pairs to synchronize. */
	private final Map<Class<?>, List<JoinColumnBinding>> joinColumnBindingCache =
			new ConcurrentHashMap<>();

	@Autowired
	private UserContext userContext;

	/**
	 * Stamps audit fields on the entity passed to {@code save}/{@code saveAndFlush}.
	 *
	 * @param jp     the intercepted join point (unused)
	 * @param entity the entity about to be saved
	 */
	@Before(value = "(execution(* IRepository+.save(..)) || execution(* IRepository+.saveAndFlush(..))) && args(entity)", argNames = "entity")
	public void saveBefore(JoinPoint jp, JpaEntity<? extends Serializable> entity) {
		stampAuditFields(entity, new Date());
	}

	/**
	 * Stamps audit fields on every entity passed to {@code saveAll}.
	 *
	 * <p>The parameter is an {@link Iterable} (not a {@code List}) because {@code args()}
	 * binding only matches when the runtime argument is an instance of the parameter type;
	 * a {@code List} parameter silently skipped batches passed as a {@code Set} etc.
	 *
	 * @param jp       the intercepted join point (unused)
	 * @param entities the entities about to be saved
	 */
	@Before(value = "execution(* IRepository+.saveAll(..)) && args(entities)", argNames = "entities")
	public void saveAllBefore(JoinPoint jp, Iterable<? extends JpaEntity<? extends Serializable>> entities) {
		// One timestamp for the whole batch, so entities saved together share it.
		Date now = new Date();
		for (JpaEntity<? extends Serializable> entity : entities) {
			stampAuditFields(entity, now);
		}
	}

	/**
	 * After {@code save}/{@code saveAndFlush}, mirrors each {@link JoinColumn} association's id
	 * into the {@link Column} field of the same column name on the saved entity: {@code null}
	 * when the association is {@code null}, its {@link JpaEntity#getId()} when it is a
	 * {@link JpaEntity}, and nothing otherwise.
	 *
	 * <p>Only fields declared directly on the entity's runtime class are inspected. Runs as
	 * {@code @After}, i.e. also when the save throws.
	 * NOTE(review): this updates the argument instance; if the repository returns a different
	 * (merged) instance, that copy is not touched — confirm this is intended.
	 *
	 * @param jp     the intercepted join point (unused)
	 * @param entity the entity that was passed to the save call
	 * @throws IllegalArgumentException if a column field's type cannot hold the association id
	 * @throws IllegalAccessException   if reflective access is denied
	 */
	@After(value = "(execution(* IRepository+.save(..)) || execution(* IRepository+.saveAndFlush(..))) && args(entity)", argNames = "entity")
	public void saveAfter(JoinPoint jp, JpaEntity<? extends Serializable> entity) throws IllegalArgumentException, IllegalAccessException {
		List<JoinColumnBinding> bindings = joinColumnBindingCache.computeIfAbsent(
				entity.getClass(), IRepositoryAspect::resolveJoinColumnBindings);
		for (JoinColumnBinding binding : bindings) {
			Object associated = binding.joinField.get(entity);
			if (associated == null) {
				binding.columnField.set(entity, null);
			} else if (associated instanceof JpaEntity) {
				binding.columnField.set(entity, ((JpaEntity<?>) associated).getId());
			}
			// Associations of any other type are left untouched.
		}
	}

	/**
	 * Sets update time/user on {@code entity}, plus create time/user when it is new.
	 *
	 * @param entity the entity to stamp
	 * @param now    the timestamp to use for both create and update time
	 */
	private void stampAuditFields(JpaEntity<? extends Serializable> entity, Date now) {
		if (entity.isNew()) {
			entity.setCreateTime(now);
			entity.setCreateUser(userContext.getUserId());
		}
		entity.setUpdateTime(now);
		entity.setUpdateUser(userContext.getUserId());
	}

	/**
	 * Pairs each {@link JoinColumn} field of {@code clazz} with the {@link Column} field that
	 * declares the same non-blank column name, making both fields accessible once.
	 *
	 * <p>Blank names (the annotations' default) are ignored: otherwise an unnamed join column
	 * would be matched to an arbitrary unnamed {@code @Column} field. If several {@code @Column}
	 * fields share a name, the last one in declaration order wins.
	 *
	 * @param clazz the entity class to inspect (declared fields only)
	 * @return an unmodifiable list of bindings, in join-field declaration order
	 */
	private static List<JoinColumnBinding> resolveJoinColumnBindings(Class<?> clazz) {
		Field[] fields = clazz.getDeclaredFields();

		Map<String, Field> columnFieldsByName = new HashMap<>();
		for (Field field : fields) {
			Column column = field.getAnnotation(Column.class);
			if (column != null && !column.name().isEmpty()) {
				columnFieldsByName.put(column.name(), field);
			}
		}

		List<JoinColumnBinding> bindings = new ArrayList<>();
		for (Field field : fields) {
			JoinColumn joinColumn = field.getAnnotation(JoinColumn.class);
			if (joinColumn == null || joinColumn.name().isEmpty()) {
				continue;
			}
			Field columnField = columnFieldsByName.get(joinColumn.name());
			if (columnField != null) {
				// These Field objects are private copies owned by this cache, so granting
				// access permanently is safe and avoids racy per-call toggling.
				field.setAccessible(true);
				columnField.setAccessible(true);
				bindings.add(new JoinColumnBinding(field, columnField));
			}
		}
		return Collections.unmodifiableList(bindings);
	}

	/** An association field paired with the foreign-key column field that mirrors its id. */
	private static final class JoinColumnBinding {
		final Field joinField;
		final Field columnField;

		JoinColumnBinding(Field joinField, Field columnField) {
			this.joinField = joinField;
			this.columnField = columnField;
		}
	}
}
